{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import IPython.display as ipd\n",
    "\n",
    "import sys\n",
    "sys.path.append('./waveglow/')\n",
    "\n",
    "import librosa\n",
    "import numpy as np\n",
    "import os\n",
    "import glob\n",
    "import json\n",
    "\n",
    "import torch\n",
    "from text import text_to_sequence, cmudict\n",
    "from text.symbols import symbols\n",
    "import commons\n",
    "import attentions\n",
    "import modules\n",
    "import models\n",
    "import utils\n",
    "\n",
    "# Load the pretrained WaveGlow vocoder and move it to the GPU in eval mode.\n",
    "# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code -- only load files from a trusted source.\n",
    "waveglow_path = './waveglow/waveglow_256channels_ljs_v3.pt' # or change to the latest version of the pretrained WaveGlow.\n",
    "waveglow = torch.load(waveglow_path)['model']\n",
    "waveglow = waveglow.remove_weightnorm(waveglow)\n",
    "_ = waveglow.cuda().eval()\n",
    "\n",
    "# Optional speed-up: let NVIDIA apex convert WaveGlow with opt_level \"O3\".\n",
    "# apex is not a hard requirement, so keep the unconverted model when it is not installed\n",
    "# instead of failing the whole cell.\n",
    "try:\n",
    "    from apex import amp\n",
    "except ImportError:\n",
    "    print(\"apex is not installed; skipping amp.initialize and keeping WaveGlow as loaded.\")\n",
    "else:\n",
    "    waveglow, _ = amp.initialize(waveglow, [], opt_level=\"O3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option A -- your own trained model: hyperparameters and checkpoint path are resolved from model_dir.\n",
    "model_dir = \"./logs/your_dir/\"\n",
    "hps = utils.get_hparams_from_dir(model_dir)\n",
    "checkpoint_path = utils.latest_checkpoint_path(model_dir)\n",
    "\n",
    "# Option B -- a provided pretrained model: uncomment and edit the two lines below instead.\n",
    "# hps = utils.get_hparams_from_file(\"./configs/any_config_file.json\")\n",
    "# checkpoint_path = \"/path/to/pretrained_model\"\n",
    "\n",
    "# The symbol count grows by one when the config sets \"add_blank\" (True counts as 1).\n",
    "n_symbols = len(symbols) + getattr(hps.data, \"add_blank\", False)\n",
    "model = models.FlowGenerator(n_symbols, out_channels=hps.data.n_mel_channels, **hps.model)\n",
    "model = model.to(\"cuda\")\n",
    "\n",
    "utils.load_checkpoint(checkpoint_path, model)\n",
    "# Store the decoder inverse up front so Jacobians are not calculated during decoding (faster).\n",
    "model.decoder.store_inverse()\n",
    "_ = model.eval()\n",
    "\n",
    "cmu_dict = cmudict.CMUDict(hps.data.cmudict_path)\n",
    "\n",
    "# normalizing & type casting\n",
    "def normalize_audio(x, max_wav_value=hps.data.max_wav_value):\n",
    "    \"\"\"Peak-normalize a waveform and cast it to 16-bit integers.\n",
    "\n",
    "    Args:\n",
    "        x: array-like of audio samples.\n",
    "        max_wav_value: value the largest-magnitude sample is scaled to before\n",
    "            the result is clipped to [-32768, 32767].\n",
    "\n",
    "    Returns:\n",
    "        numpy array of dtype int16 with the same shape as x. Empty or all-zero\n",
    "        input yields zeros instead of dividing by a zero peak.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    peak = np.abs(x).max() if x.size else 0\n",
    "    if peak == 0:\n",
    "        # Silent or empty signal: x / peak would raise or produce NaNs whose int16 cast is undefined.\n",
    "        return np.zeros(x.shape, dtype=\"int16\")\n",
    "    return np.clip((x / peak) * max_wav_value, -32768, 32767).astype(\"int16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence to synthesize.\n",
    "tst_stn = \"Glow TTS is really awesome !\"\n",
    "\n",
    "if getattr(hps.data, \"add_blank\", False):\n",
    "    # Model trained with \"add_blank\": intersperse the extra id len(symbols) into the sequence.\n",
    "    text_norm = text_to_sequence(tst_stn.strip(), ['english_cleaners'], cmu_dict)\n",
    "    text_norm = commons.intersperse(text_norm, len(symbols))\n",
    "else: # If not using \"add_blank\" option during training, adding spaces at the beginning and the end of utterance improves quality\n",
    "    tst_stn = \" \" + tst_stn.strip() + \" \"\n",
    "    # Pass the padded string as-is: stripping it again here would remove the spaces added above.\n",
    "    text_norm = text_to_sequence(tst_stn, ['english_cleaners'], cmu_dict)\n",
    "\n",
    "# Add a leading batch axis of size 1.\n",
    "sequence = np.array(text_norm)[None, :]\n",
    "# Echo the symbols that will be fed to the model; ids beyond the symbol table are shown as <BNK>.\n",
    "print(\"\".join([symbols[c] if c < len(symbols) else \"<BNK>\" for c in sequence[0]]))\n",
    "# torch.autograd.Variable is deprecated; a plain tensor behaves the same.\n",
    "x_tst = torch.from_numpy(sequence).cuda().long()\n",
    "x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # Sampling controls passed straight to the model call below.\n",
    "    noise_scale = .667\n",
    "    length_scale = 1.0\n",
    "    (y_gen_tst, *_), *_, (attn_gen, *_) = model(x_tst, x_tst_lengths, gen=True, noise_scale=noise_scale, length_scale=length_scale)\n",
    "    try:\n",
    "        # First try half-precision input (presumably what a WaveGlow converted by apex amp expects).\n",
    "        audio = waveglow.infer(y_gen_tst.half(), sigma=.666)\n",
    "    except RuntimeError:\n",
    "        # Presumably a dtype mismatch with a full-precision WaveGlow: retry without the cast.\n",
    "        # Catching RuntimeError only, so KeyboardInterrupt and unrelated errors are not swallowed.\n",
    "        audio = waveglow.infer(y_gen_tst, sigma=.666)\n",
    "# Last expression of the cell: rendered as an audio player widget.\n",
    "ipd.Audio(normalize_audio(audio[0].clamp(-1,1).data.cpu().float().numpy()), rate=hps.data.sampling_rate)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
